import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.manifold import TSNE
# from tsnecuda import TSNE

# Display label for each integer action id; position k labels action id k.
action_names = [
    'Turn left',
    'Turn right',
    'Forward',
    'Left',
    'Right',
    'Forward left',
    'Forward right',
]
# Custom categorical colormap: the first len(action_names) colours of Pastel2,
# one colour per action.
cmap = ListedColormap(plt.get_cmap('Pastel2').colors[:len(action_names)])


def _subsample_per_action(actions, num_actions, max_per_action):
    """Return indices keeping at most ``max_per_action`` samples of each action id.

    Indices are grouped by action id (0 .. num_actions - 1). Within a group they
    stay in their original ascending order.
    """
    per_action = [
        torch.nonzero(actions == action, as_tuple=True)[0][:max_per_action]
        for action in range(num_actions)
    ]
    return torch.cat(per_action)


def _plot_action_embedding(embeddings_2d, actions, num_actions, out_path):
    """Scatter 2-D points coloured by action id, save to ``out_path``, then show."""
    # One colour per action id; cycle the palette if there are more ids than colours.
    palette = [cmap.colors[k % len(cmap.colors)] for k in range(num_actions)]
    # Use a generic label for ids beyond the named actions, so that the tick count
    # and the label count always match.
    labels = [action_names[k] if k < len(action_names) else f"Action {k}"
              for k in range(num_actions)]

    fig, ax = plt.subplots(figsize=(7, 6))
    # Pin the colour range to [-0.5, num_actions - 0.5]. Action k then always maps
    # to palette[k], even if some ids are missing from this buffer, and each
    # colorbar tick sits at the centre of its colour band instead of on an edge.
    sc = ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=actions.numpy(),
                    cmap=ListedColormap(palette), vmin=-0.5, vmax=num_actions - 0.5)
    cbar = fig.colorbar(sc, ax=ax, ticks=range(num_actions))
    cbar.ax.set_yticklabels(labels, fontsize=12)
    # Hide tick marks and tick labels but keep the axes frame.
    ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
                   labelbottom=False,
                   labelleft=False)
    fig.savefig(out_path)
    plt.show()
    # Release the figure. Otherwise figures accumulate across buffers
    # (e.g. on non-interactive backends, where show() does not close them).
    plt.close(fig)


def read_and_visual(file_path, *, output_dir="../../", max_per_action=1000):
    """Load saved representation buffers and plot a 2-D t-SNE projection of each.

    The object loaded from ``file_path`` is iterated as ``(embeddings, actions)``
    pairs:
    - ``actions`` must be a 1-D tensor of non-negative integer action ids
      (required by ``torch.bincount``).
    - ``embeddings`` must have one row per action sample; t-SNE needs it to be
      2-D, i.e. (N, D).

    For each buffer this function:
    - prints the per-action sample count;
    - keeps at most ``max_per_action`` samples per action;
    - embeds them with t-SNE;
    - saves a scatter plot to ``<output_dir>/action_representations_<i>.svg``
      and shows it.

    Buffers that are empty, or that keep fewer than 2 samples, are reported and
    skipped.

    Args:
        file_path: Path passed to ``torch.load``.
        output_dir: Directory the SVG files are written to.
        max_per_action: Per-action cap on plotted samples (bounds t-SNE cost).
    """
    import os

    # torch.load is pickle-based: only load buffer files from trusted sources.
    representation_buffers = torch.load(file_path)
    for i, (embeddings, actions) in enumerate(representation_buffers):
        actions = torch.as_tensor(actions).cpu()
        # Count the samples of each action id.
        action_counter = torch.bincount(actions, minlength=len(action_names))
        print(f"Action distribution in buffer {i}: {action_counter}")
        if actions.numel() == 0:
            print(f"Buffer {i} is empty; skipping.")
            continue
        num_actions = actions.max().item() + 1

        indices = _subsample_per_action(actions, num_actions, max_per_action)
        n_samples = indices.numel()
        if n_samples < 2:
            print(f"Buffer {i} has fewer than 2 samples; skipping.")
            continue
        embeddings = torch.as_tensor(embeddings).detach().cpu()[indices].numpy()
        actions = actions[indices]

        # sklearn requires perplexity < n_samples, so shrink it for small buffers.
        tsne = TSNE(n_components=2, perplexity=min(30, n_samples - 1), random_state=42)
        embeddings_2d = tsne.fit_transform(embeddings)

        out_path = os.path.join(output_dir, f"action_representations_{i}.svg")
        _plot_action_embedding(embeddings_2d, actions, num_actions, out_path)


if __name__ == '__main__':
    # Default buffer file, resolved relative to the current working directory.
    buffers_file = "../../02representation_buffers.pkl"
    read_and_visual(buffers_file)